Split
将一个张量(Tensor)沿着指定的轴(axis)拆分为多个子张量。子张量在指定轴上的大小由 split_sizes 数组决定。
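\[\text{input.shape} = [d_0, d_1, \dots, d_{axis}, \dots, d_{n-1}], \quad n = \text{input\_ndim}\]
\[\text{对于第 } j \text{ 个输出 } (j = 0, 1, \dots, \text{num\_split}-1)\text{: } \text{output}_j\text{.shape} = [d_0, \dots, d_{axis-1}, \text{split\_sizes}[j], d_{axis+1}, \dots, d_{n-1}], \quad \sum_{j=0}^{\text{num\_split}-1} \text{split\_sizes}[j] = d_{axis}\]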
- 输入:
input - 输入数据起始地址。
axis - 指定拆分的维度轴,从 0 开始计数,取值范围为 0 ~ input_ndim-1。
input_shape - 输入张量的形状数组地址。
input_ndim - 输入张量的维度数。
num_split - 拆分出的子张量个数。
split_sizes - 一个数组,包含每个子张量在拆分轴上的长度。
core_mask(int) - 核掩码。仅共享存储版本(*_split_s)有此参数,且为必填;私有存储版本(*_split_p)无此参数。
- 输出:
outputs - 长度为 num_split 的指针数组地址,其中第 j 个元素指向第 j 个子张量的存储地址。
- 支持平台:
FT78NE、MT7004
备注
FT78NE 支持 int8, int16, int32, fp32, fp64, cplx64, cplx128
MT7004 支持 fp16, fp32, int16, int32, cplx64
split_sizes 中各元素之和必须等于输入张量在 axis 维度上的长度(即 input_shape[axis])。对于复数类型(cplx64 / cplx128),拆分逻辑与实数类型一致,但地址偏移按复数对计算:每个复数元素由相邻的实部和虚部组成,占 2 个 float(cplx64)或 2 个 double(cplx128)。
共享存储版本:
-
void i8_split_s(int8_t *input, int8_t *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes, int core_mask)
-
void i16_split_s(int16_t *input, int16_t *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes, int core_mask)
-
void i32_split_s(int32_t *input, int32_t *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes, int core_mask)
-
void hp_split_s(half *input, half *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes, int core_mask)
-
void fp_split_s(float *input, float *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes, int core_mask)
-
void dp_split_s(double *input, double *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes, int core_mask)
-
void c64_split_s(float *input, float *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes, int core_mask)
-
void c128_split_s(double *input, double *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes, int core_mask)
C调用示例:
1//FT78NE示例(共享存储) 2#include <stdio.h> 3#include "78NE/utils.h" 4 5int main(int argc, char* argv[]) { 6 float *input = (float *)0xA0000000; 7 float *out0 = (float *)0xB0000000; 8 float *out1 = (float *)0xB1000000; 9 float *outputs[] = { out0, out1 }; 10 int input_shape[] = { 2, 10, 4 }; 11 int split_sizes[] = { 6, 4 }; 12 int axis = 1; 13 int input_ndim = 3; 14 int num_split = 2; 15 int core_mask = 0b1011; 16 17 fp_split_s(input, outputs, axis, input_shape, input_ndim, num_split, split_sizes, core_mask); 18 return 0; 19}
私有存储版本:
-
void i8_split_p(int8_t *input, int8_t *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes)
-
void i16_split_p(int16_t *input, int16_t *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes)
-
void i32_split_p(int32_t *input, int32_t *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes)
-
void hp_split_p(half *input, half *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes)
-
void fp_split_p(float *input, float *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes)
-
void dp_split_p(double *input, double *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes)
-
void c64_split_p(float *input, float *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes)
-
void c128_split_p(double *input, double *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes)
C调用示例:
// MT7004 example (private storage)
//
// fp_split_p must be declared before use; calling an undeclared function is
// invalid since C99. In real code, include the MT7004 library header that
// declares it. The documented prototype is repeated here so the example is
// self-contained. NOTE(review): replace with the platform header once its name
// is confirmed.
void fp_split_p(float *input, float *outputs[], int axis, int *input_shape,
                int input_ndim, int num_split, int *split_sizes);

int main(void) {
    float *input = (float *)0x10000000;   // private-storage address; 20 x 10 input tensor
    float *out0 = (float *)0x10010000;    // receives rows 0..9  (10 x 10)
    float *out1 = (float *)0x10020000;    // receives rows 10..19 (10 x 10)
    float *outputs[] = { out0, out1 };

    int input_shape[] = { 20, 10 };
    int split_sizes[] = { 10, 10 };       // 10 + 10 == input_shape[axis] == 20
    int axis = 0;
    // Named and derived instead of passing bare literals, so they stay in sync.
    int input_ndim = (int)(sizeof input_shape / sizeof input_shape[0]);
    int num_split = (int)(sizeof split_sizes / sizeof split_sizes[0]);

    fp_split_p(input, outputs, axis, input_shape, input_ndim, num_split, split_sizes);
    return 0;
}